{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# ONNX Conversion of the GPT-2 Small Model\n",
        "This notebook is a companion to chapter 4 of the book \"Domain Specific LLMs in Action\" by Guglielmo Iozzia, [Manning Publications](https://www.manning.com/), 2025.  \n",
        "The code in this notebook introduces readers to the [ONNX](https://onnx.ai/) format and [ONNX Runtime](https://onnxruntime.ai/) on GPU, using the [GPT-2 Small](https://huggingface.co/openai-community/gpt2) model. It requires hardware acceleration (GPU).  \n",
        "More details about the code can be found in the corresponding book chapter."
      ],
      "metadata": {
        "id": "rkGavbEb2IF_"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "*** **Update September 2025: the code in this notebook is no longer compatible with PyTorch 2.1 or later, or with the Hugging Face Transformers releases that support the latest PyTorch. We therefore need to downgrade the PyTorch and Transformers packages.** ***"
      ],
      "metadata": {
        "id": "BOVmWe0Tdkjp"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install torch==2.0.1 transformers==4.31.0"
      ],
      "metadata": {
        "id": "huhP8taCdmVd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Install the missing requirements (only ONNX and ONNX Runtime for GPU)."
      ],
      "metadata": {
        "id": "R9vCD6LM3iSj"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SdjR55D_p_Sy"
      },
      "outputs": [],
      "source": [
        "!pip install onnx onnxruntime-gpu"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Download the GPT-2 Small model from the Hugging Face Hub and load it into the GPU memory."
      ],
      "metadata": {
        "id": "DyNnEmjq3qBP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from transformers import GPT2Tokenizer, AutoModelForCausalLM\n",
        "\n",
        "model_id = 'openai-community/gpt2'\n",
        "tokenizer = GPT2Tokenizer.from_pretrained(model_id)\n",
        "model = AutoModelForCausalLM.from_pretrained(model_id)\n",
        "device = torch.device(\"cuda\")\n",
        "model.eval().to(device)"
      ],
      "metadata": {
        "id": "-lED6QTrqe1A"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Verify that the downloaded model works as expected."
      ],
      "metadata": {
        "id": "BpdAcl2T4Ryr"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "inputs = tokenizer(\"The story so far: in the beginning, the universe was created.\", return_attention_mask=False, return_tensors=\"pt\")\n",
        "# Move the input tensors to the GPU explicitly, so they sit on the same device as the model\n",
        "inputs = inputs.to(device)\n",
        "print(\"input tensors\")\n",
        "print(inputs)\n",
        "print(\"input tensor shape\")\n",
        "print(inputs[\"input_ids\"].size())\n",
        "\n",
        "with torch.no_grad():\n",
        "    outputs = model(**inputs)\n",
        "\n",
        "logits = outputs.logits\n",
        "print(\"output tensor\")\n",
        "print(logits)\n",
        "print(\"output shape\")\n",
        "print(logits.shape)"
      ],
      "metadata": {
        "id": "wodI_-_8rKk4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Create a directory in which to store the ONNX-converted model."
      ],
      "metadata": {
        "id": "DfFeZ9eQ5B1J"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "\n",
        "output_dir = os.path.join(\".\", \"onnx_models\")\n",
        "# exist_ok=True makes this a no-op when the directory already exists\n",
        "os.makedirs(output_dir, exist_ok=True)\n",
        "export_model_path = os.path.join(output_dir, 'gpt-2.onnx')"
      ],
      "metadata": {
        "id": "iBKjgK1zsGd3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Create an input tensor to be used for model conversion."
      ],
      "metadata": {
        "id": "KE3NFjXi5HZl"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "tokenized_inputs = tokenizer(\"The story so far: in the beginning, the universe was created.\",\n",
        "                             return_attention_mask=False,\n",
        "                             return_tensors=\"pt\")\n",
        "tokenized_inputs.to(device)\n",
        "inputs_sample = {\n",
        "        'input_ids':  tokenized_inputs['input_ids']\n",
        "    }"
      ],
      "metadata": {
        "id": "BJjO4XHnsJzP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Convert the model to ONNX."
      ],
      "metadata": {
        "id": "tJfo_LC_6XNi"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "with torch.no_grad():\n",
        "  torch.onnx.export(model,\n",
        "                    inputs_sample,\n",
        "                    export_model_path,\n",
        "                    export_params=True,\n",
        "                    opset_version=15,\n",
        "                    do_constant_folding=True,\n",
        "                    input_names=['input_ids']\n",
        "                    )"
      ],
      "metadata": {
        "id": "QJGVh3L9yDbj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Optimize the exported model."
      ],
      "metadata": {
        "id": "37n1P7ys6cmx"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from onnxruntime.transformers import optimizer\n",
        "\n",
        "optimized_model_path = os.path.join(output_dir, 'gpt-2-onnx_opt_gpu.onnx')\n",
        "optimized_model = optimizer.optimize_model(export_model_path,\n",
        "                                           model_type='gpt2',\n",
        "                                           use_gpu=True,\n",
        "                                           num_heads=12,\n",
        "                                           hidden_size=768,\n",
        "                                           verbose=True)\n",
        "optimized_model.save_model_to_file(optimized_model_path)"
      ],
      "metadata": {
        "id": "Uu_cSLazzo8D"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Benchmark inference with the original model."
      ],
      "metadata": {
        "id": "afqJ0uQy2N42"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import time\n",
        "\n",
        "with torch.inference_mode():\n",
        "    sample_output = model.generate(inputs.input_ids, max_length=64, pad_token_id=50256)\n",
        "    print(tokenizer.decode(sample_output[0], skip_special_tokens=False))\n",
        "    # Warm-up runs, excluded from the timing\n",
        "    for _ in range(2):\n",
        "        _ = model.generate(inputs.input_ids, max_length=64, pad_token_id=50256)\n",
        "        torch.cuda.synchronize()\n",
        "    # Timed runs: average wall-clock time over 10 generations\n",
        "    start = time.time()\n",
        "    for _ in range(10):\n",
        "        _ = model.generate(inputs.input_ids, max_length=256, pad_token_id=50256)\n",
        "        torch.cuda.synchronize()\n",
        "    print(f\"----\\nPyTorch: {(time.time() - start)/10:.2f}s/sequence\")\n",
        "_ = model.cpu()"
      ],
      "metadata": {
        "id": "dVwWlsz-2SgY"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Benchmark inference with the ONNX converted model."
      ],
      "metadata": {
        "id": "OK8EbSUO4dxC"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import onnxruntime\n",
        "\n",
        "session = onnxruntime.InferenceSession(export_model_path, providers=[\"CUDAExecutionProvider\"])\n",
        "onnx_input_ids = tokenizer(\"The story so far: in the beginning, the universe was created.\",\n",
        "                           return_attention_mask=False,\n",
        "                           return_tensors=\"np\")\n",
        "ort_inputs = {\n",
        "    \"input_ids\": onnx_input_ids['input_ids']\n",
        "}\n",
        "\n",
        "# Warm-up runs, excluded from the timing\n",
        "for _ in range(2):\n",
        "  ort_outputs = session.run(None, ort_inputs)\n",
        "# Timed runs: average wall-clock time over 10 forward passes\n",
        "start = time.time()\n",
        "for _ in range(10):\n",
        "  ort_outputs = session.run(None, ort_inputs)\n",
        "print(f\"----\\nONNX Runtime: {(time.time() - start)/10:.2f}s/sequence\")"
      ],
      "metadata": {
        "id": "qf1UfaNW4uJ6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Benchmark inference with the optimized ONNX model."
      ],
      "metadata": {
        "id": "kYCPk9vq4umZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import onnxruntime\n",
        "\n",
        "opt_session = onnxruntime.InferenceSession(optimized_model_path, providers=[\"CUDAExecutionProvider\"])\n",
        "onnx_input_ids = tokenizer(\"The story so far: in the beginning, the universe was created.\",\n",
        "                           return_attention_mask=False,\n",
        "                           return_tensors=\"np\")\n",
        "ort_inputs = {\n",
        "    \"input_ids\": onnx_input_ids['input_ids']\n",
        "}\n",
        "\n",
        "# Warm-up runs, excluded from the timing\n",
        "for _ in range(2):\n",
        "  ort_outputs = opt_session.run(None, ort_inputs)\n",
        "# Timed runs: average wall-clock time over 10 forward passes\n",
        "start = time.time()\n",
        "for _ in range(10):\n",
        "  ort_outputs = opt_session.run(None, ort_inputs)\n",
        "print(f\"----\\nONNX Runtime (optimized): {(time.time() - start)/10:.2f}s/sequence\")"
      ],
      "metadata": {
        "id": "PRBmsHaU1GPK"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}